from __future__ import annotations
import os
import re
import json
import random
import gensim
import openai
import spacy
import torch
import numpy as np
from openai import OpenAI
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from torch.nn import functional as F
from nltk import pos_tag
from nltk.stem import WordNetLemmatizer
from transformers import AutoTokenizer, AutoModel, logging

from utils.util import read_json, read_txt, write_json
from utils.token_count_decorator import token_count_decorator
from planning.src.protocol import Protocol
from planning.data_process.rag_retrieval_generation import Retrieval

# `logging` here is transformers' logging module (imported above), not the stdlib one.
# Restrict transformers' logging to errors only, to suppress warnings emitted while
# loading the SciBERT tokenizer/model in Representer.__init__.
logging.set_verbosity_error()

class Representer:
    """Build planning representations for a protocol from a retrieval corpus and DSL specs.

    On construction this loads the domain's operation/production DSL JSON, a
    word2vec model (operation-verb matching), a spaCy pipeline, and SciBERT
    (flow-unit/component matching), plus prompt templates and dataset metadata.
    """

    # Separator the pseudofunction-generation prompt's output is split on:
    # text before it is the pseudofunction definitions, text after it the steps.
    _PROTOCOL_STEPS_MARKER = "# Protocol steps"

    def __init__(self, domain) -> None:
        self.domain = domain
        self.retrieval = Retrieval(domain)
        self.data_path = "planning/data/"
        self.corpus_path = f"planning/data/corpus/{self.domain}/"
        self.operation_dsl, self.production_dsl = self.load_dsl()
        print("dsl load")
        self.word2vec_model = gensim.models.KeyedVectors.load_word2vec_format("dataset/GoogleNews-vectors-negative300.bin.gz", binary=True)
        print("word2vec load")
        self.lemmatizer = WordNetLemmatizer()
        self.nlp = spacy.load("en_core_web_trf")
        print("en_core_web_trf load")
        self.tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
        self.model = AutoModel.from_pretrained("allenai/scibert_scivocab_uncased")
        # Requires production_dsl, tokenizer and model to be set (may compute embeddings).
        self.component_embeddings = self.__load_component_embedding(matrix_path=os.path.join(self.data_path, f"{domain}_component_emb.npy"))
        self.operation_extraction_prompt = read_txt("planning/data/prompt/operation_extraction.txt")
        self.flowunit_extraction_prompt = read_txt("planning/data/prompt/flowunit_extraction.txt")
        self.program_components_extraction_prompt = read_txt("planning/data/prompt/program_components_extraction.txt")
        self.pseudofunctions_generation_prompt = read_txt("planning/data/prompt/pseudofunctions_generation.txt")
        self.dataset_metadata_path = "planning/data/dataset_metadata.json"
        self.dataset_metadata = read_json(self.dataset_metadata_path)

    def load_dsl(self):
        """Return ``(operation_dsl, production_dsl)`` read from ``dsl_result/<domain>/``."""
        operation_dsl = read_json(f"dsl_result/{self.domain}/operation_dsl.json")
        production_dsl = read_json(f"dsl_result/{self.domain}/production_dsl.json")
        return operation_dsl, production_dsl

    def represent(self, protocol: Protocol, mode):
        '''
        Get representations for planning for the input protocol.

        Args:
            protocol (Protocol): The protocol to build representations for.
            mode (str): The mode of representations for planning. Possible values are:
                - 'flatten': Full procedure of similar protocols from corpus.
                - 'atomic': Pseudofunctions of similar protocols from corpus.
                - 'atomic-internal': Pseudofunctions of similar protocols from corpus, including itself
                - 'dsl': operation DSL specifications of relevant operations
                - 'multi-dsl': Operation DSL specifications of relevant operations & Production DSL specifications of relevant flowunits

        Returns:
            str for 'flatten', 'atomic', 'atomic-internal' and 'dsl' (the latter a
            JSON string); a tuple of two JSON strings (operation specs, production
            specs) for 'multi-dsl'.

        Raises:
            ValueError: If ``mode`` is not one of the values above.
        '''
        if mode == "flatten":
            return self.__fetch_flatten_representation(protocol)

        elif mode == "atomic":
            pseudofunctions_str = self.__fetch_atomic_representation(protocol)
            return self.__split_shuffle_pseudofunctions(pseudofunctions_str)

        elif mode == "atomic-internal":
            pseudofunctions_str = self.__fetch_atomic_representation(protocol, internal=True)
            return self.__split_shuffle_pseudofunctions(pseudofunctions_str)

        elif mode == "dsl":
            oper_repr, _ = self.__fetch_dsl_representation(protocol, mode="dsl")
            return json.dumps(oper_repr, indent=4, ensure_ascii=False)

        elif mode == "multi-dsl":
            oper_repr, prod_repr = self.__fetch_multi_dsl_representation(protocol)
            return (
                json.dumps(oper_repr, indent=4, ensure_ascii=False),
                json.dumps(prod_repr, indent=4, ensure_ascii=False),
            )

        else:
            raise ValueError(f"Invalid mode: {mode}. Choose from ['flatten', 'atomic', 'atomic-internal', 'dsl', 'multi-dsl'].")

    def __corpus_file(self, doc_id):
        """Path of the corpus JSON file for a retrieved document id."""
        return os.path.join(self.corpus_path, f"{doc_id}.json")

    def __fetch_flatten_representation(self, protocol: Protocol):
        """Concatenate the procedure lines of every protocol retrieved for ``protocol.title``."""
        representations = []
        for doc_id in self.retrieval.run_query(query=protocol.title):
            sim_protocol = read_json(self.__corpus_file(doc_id))
            representations.append("\n".join(sim_protocol["procedures"]))
        return "\n".join(representations)

    def __fetch_atomic_representation(self, protocol: Protocol, internal=False):
        """Collect pseudofunction text for every protocol retrieved for ``protocol.title``.

        Pseudocode previously generated for a corpus protocol is reused from its
        ``generated_pseudocode`` field; otherwise it is generated via the LLM and
        cached back into the corpus file. When ``internal`` is true,
        ``protocol.pseudofunctions`` is appended as well.
        """
        representations = []
        for doc_id in self.retrieval.run_query(query=protocol.title):
            corpus_file = self.__corpus_file(doc_id)
            sim_protocol = read_json(corpus_file)
            cached = sim_protocol.get("generated_pseudocode")
            if cached:
                representations.append(cached.split(self._PROTOCOL_STEPS_MARKER)[0])
                continue
            pseudofunctions = self.__generate_pseudofunctions(sim_protocol, corpus_file)
            if pseudofunctions:
                representations.append(pseudofunctions)
        if internal:
            representations.append(protocol.pseudofunctions)
        return "\n".join(representations)

    def __generate_pseudofunctions(self, sim_protocol, corpus_file, attempts=5):
        """Ask the LLM for pseudofunctions of ``sim_protocol``; cache and return them, or None."""
        prompt = (
            self.pseudofunctions_generation_prompt
            .replace("{title}", sim_protocol["title"])
            .replace("{protocol}", "\n".join(sim_protocol["procedures"]))
        )
        for _ in range(attempts):
            response = self.__chatgpt_function(content=prompt) or ""
            pseudofunctions, marker, pseudocode = response.partition(self._PROTOCOL_STEPS_MARKER)
            # Accept only replies that contain the marker with non-empty text on both sides.
            if marker and pseudofunctions and pseudocode:
                sim_protocol["generated_pseudocode"] = response
                try:
                    write_json(corpus_file, sim_protocol)
                except OSError as err:
                    # Caching is best-effort: keep the result even if it can't be persisted.
                    print(f"Failed to cache pseudocode to {corpus_file}: {err}")
                return pseudofunctions
        return None

    def __fetch_dsl_representation(self, protocol: Protocol, mode="dsl"):
        """Return ``(oper_repr, prod_repr)`` built from the protocol's DSL program.

        ``mode`` selects ``protocol.dsl_program`` ('dsl') or
        ``protocol.multi_dsl_program`` ('multi-dsl').

        Raises:
            ValueError: If ``mode`` is neither 'dsl' nor 'multi-dsl'.
        """
        if mode == "dsl":
            program = protocol.dsl_program
        elif mode == "multi-dsl":
            program = protocol.multi_dsl_program
        else:
            raise ValueError(f"Invalid mode: {mode}. Choose from ['dsl', 'multi-dsl'].")
        oper_repr = self.__match_operations(self.__get_operations_sequence(program))
        prod_repr = self.__match_components(self.__get_flowunits(program))
        return oper_repr, prod_repr

    def __fetch_multi_dsl_representation(self, protocol: Protocol):
        """Return ``(oper_repr, prod_repr)`` built from ``protocol.multi_dsl_program``."""
        return self.__fetch_dsl_representation(protocol, mode="multi-dsl")

    def __match_operations(self, operation_sequence):
        """Map operation verbs to opcodes; keep one randomly sampled spec per exact-name match."""
        oper_repr = {}
        for operation in operation_sequence:
            opcode = self.__similarity_opcode(operation)
            if opcode == "NONE":
                continue
            # Only verbs that literally name an opcode contribute a spec; nearest-neighbour
            # matches with a different name are dropped.
            if operation.lower() == opcode.lower():
                oper_repr[opcode] = random.sample(self.operation_dsl[opcode], 1)
        return oper_repr

    def __match_components(self, flowunits):
        """Map each flow unit to its nearest production-DSL component and collect its spec."""
        prod_repr = {}
        for flowunit in flowunits:
            component = self.__similarity_component(flowunit)
            if component != "NONE":
                prod_repr[component] = self.production_dsl[component]
        return prod_repr

    def __split_shuffle_pseudofunctions(self, pseudofunctions_str):
        """Extract ``def ... pass`` blocks, drop duplicates, shuffle, and join them with blank lines."""
        # Word boundaries keep "undefined"/"bypass" from being taken as block delimiters.
        pseudofunctions = re.findall(r'\bdef\b[\s\S]*?\bpass\b', pseudofunctions_str)
        # dict.fromkeys dedupes in first-seen order, so the shuffle is reproducible
        # under random.seed (set iteration order depends on hash randomization).
        unique_functions = list(dict.fromkeys(pseudofunctions))
        random.shuffle(unique_functions)
        return "\n\n".join(unique_functions)

    def __convert_to_sentence_list(self, steps):
        """Return the non-blank lines of ``steps`` that start with a numbered prefix like ``1.``."""
        if not steps:
            return []
        sentences = (line.strip() for line in steps.split("\n"))
        return [sentence for sentence in sentences if sentence and re.match(r'^\d+\.', sentence)]

    def __operation_extraction(self, sentence):
        """Ask the LLM for the operation verbs in ``sentence``; return lemmas or ``["NONE"]``.

        A reply is accepted only if every comma-separated item is a single
        non-punctuation token; up to 5 attempts are made.
        """
        prompt = self.operation_extraction_prompt.replace("---SENTENCES---", sentence)
        for _ in range(5):
            response = (self.__chatgpt_function(prompt) or "").strip().upper()
            if "NONE" in response:
                return ["NONE"]
            operations = [
                self.lemmatizer.lemmatize(op.strip().lower(), pos="v")
                for op in response.split(",") if op.strip()
            ]
            if operations and all(sum(1 for token in self.nlp(op) if not token.is_punct) == 1 for op in operations):
                return operations
        return ["NONE"]

    def __flowunit_extraction(self, sentence):
        """Ask the LLM for the flow units in ``sentence``; return them or ``["NONE"]``.

        An empty reply is retried (up to 5 attempts) instead of being returned as ``[]``.
        """
        prompt = self.flowunit_extraction_prompt.replace("---SENTENCES---", sentence)
        for _ in range(5):
            response = (self.__chatgpt_function(prompt) or "").strip()
            if "NONE" in response:
                return ["NONE"]
            flowunits = [flowunit.strip() for flowunit in response.split(",") if flowunit.strip()]
            if flowunits:
                return flowunits
        return ["NONE"]

    def __similarity_opcode(self, operation):
        """Return the opcode whose word2vec vector is most similar to ``operation``, or "NONE".

        "NONE" is returned for empty/"NONE" input or when the (lowercased) operation
        is out of the word2vec vocabulary. Opcodes missing from the vocabulary are skipped.
        """
        if not operation or operation == "NONE":
            return "NONE"
        operation_lower = operation.lower()
        if operation_lower not in self.word2vec_model:
            return "NONE"
        closest_word = "NONE"
        max_similarity = -1
        for opcode in self.operation_dsl:
            opcode_lower = opcode.lower()
            if opcode_lower in self.word2vec_model:
                similarity = self.word2vec_model.similarity(operation_lower, opcode_lower)
                if similarity > max_similarity:
                    max_similarity = similarity
                    closest_word = opcode
        return closest_word

    def __similarity_component(self, flowunit):
        """Return the production-DSL component whose SciBERT embedding is closest to ``flowunit``.

        Returns "NONE" for empty/"NONE" input or when the production DSL is empty.
        Rows of ``self.component_embeddings`` correspond to ``sorted(self.production_dsl)``.
        """
        if not flowunit or flowunit == "NONE":
            return "NONE"
        components = sorted(self.production_dsl)
        if not components:
            return "NONE"
        flowunit_emb = self.__get_embedding(flowunit)
        cosine_vector = cosine_similarity([flowunit_emb], self.component_embeddings)
        return components[int(np.argmax(cosine_vector[0]))]

    def __load_component_embedding(self, matrix_path):
        """Load per-component embeddings from ``matrix_path``, recomputing them if missing or stale.

        The cache is stale when its row count differs from the number of
        components in ``self.production_dsl`` (e.g. after the DSL changed), which
        would otherwise misalign indices in ``__similarity_component``.
        """
        component_names = sorted(self.production_dsl)
        if os.path.exists(matrix_path):
            embeddings = np.load(matrix_path)
            if embeddings.ndim == 2 and embeddings.shape[0] == len(component_names):
                return embeddings
            print(f"Stale component embeddings at {matrix_path}; regenerating")
        embeddings = np.array([
            self.__get_embedding(component)
            for component in tqdm(component_names, desc="Dump component embeddings")
        ])
        np.save(matrix_path, embeddings)
        return embeddings

    def __get_embedding(self, text):
        '''Return the L2-normalized [CLS] embedding of ``text`` as a 1-D numpy array.'''
        # Truncate to the model's maximum length so long inputs don't crash the encoder.
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :]
        embedding = F.normalize(embedding, p=2, dim=1)
        return embedding[0].numpy()

    @token_count_decorator(flow="together", batch=False)
    def __chatgpt_function(self, content, gpt_model="gpt-4o-mini"):
        """Send ``content`` as a single user message and return the reply text.

        Retries indefinitely on ``openai.APIError`` with exponential backoff
        (1s, doubling, capped at 60s). Returns "" when the reply has no text content.
        """
        import time

        client = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
        )
        delay = 1.0
        while True:
            try:
                chat_completion = client.chat.completions.create(
                    messages=[
                        {"role": "user", "content": content}
                    ],
                    model=gpt_model,
                )
                return chat_completion.choices[0].message.content or ""
            except openai.APIError as error:
                print(error)
                time.sleep(delay)
                delay = min(delay * 2, 60.0)

    def __get_operations_sequence(self, program: list[dict]) -> list:
        """Return the first verb of each step's "Operation" string, skipping empty results."""
        operations_sequence = []
        for step_dic in program:
            if "Operation" in step_dic:
                first_verb = self.__get_first_verb(step_dic["Operation"])
                if first_verb:
                    operations_sequence.append(first_verb)
        return operations_sequence

    def __get_first_verb(self, operation_str):
        """Return the first verb-tagged lemma in an ``_``/space-separated name, else the first lemma.

        Returns "" if the name has no non-empty tokens.
        """
        tokens = [token for token in re.split(r'[_ ]', operation_str) if token]
        if not tokens:
            return ""
        # Lowercase before lemmatizing: WordNet lookups are case-sensitive, so
        # "Incubated" would otherwise not reduce to "incubate".
        lemmatized_tokens = [self.lemmatizer.lemmatize(token.lower(), pos="v") for token in tokens]
        for word, pos in pos_tag(lemmatized_tokens):
            if pos.startswith('VB'):  # VB, VBD, VBG, VBN, VBP, VBZ
                return word
        return lemmatized_tokens[0]

    def __get_components(self, protocol: Protocol) -> list:
        """Return the protocol's flow units from dataset metadata, extracting them via the LLM if absent.

        Newly extracted flow units are stored in the metadata and written back to disk.
        """
        flowunits = self.dataset_metadata.get(self.domain, {}).get(protocol.id, {}).get("flowunits", [])
        if not flowunits:
            prompt = self.program_components_extraction_prompt.replace("---PSEUDOCODE---", json.dumps(protocol.program, indent=4, ensure_ascii=False))
            for _ in range(5):
                response = self.__chatgpt_function(prompt) or ""
                flowunits = [flowunit.strip() for flowunit in response.split(",") if flowunit.strip()]
                if flowunits:
                    domain_meta = self.dataset_metadata.setdefault(self.domain, {})
                    domain_meta.setdefault(protocol.id, {})["flowunits"] = flowunits
                    self.__dump_dataset_metadata()
                    break
        return flowunits

    def __get_flowunits(self, program: list[dict]):
        """Collect flow-unit names from a DSL program.

        If any step has a "FlowUnit" entry, the program is treated as multi-DSL and
        each such step's ``["FlowUnit"]["Component"]`` is collected; otherwise the
        ``["Precond"]["SlotArg"]`` items of every step are collected, skipping steps
        that lack them.
        """
        flowunits = []
        multi = any("FlowUnit" in step for step in program)
        for step in program:
            if multi:
                if "FlowUnit" in step:
                    flowunits.append(step["FlowUnit"]["Component"])
                continue
            try:
                flowunits.extend(step["Precond"]["SlotArg"])
            except (KeyError, TypeError):
                # Step without (well-formed) preconditions: nothing to collect.
                continue
        return flowunits

    def __get_devices(self, program: list[dict]) -> list:
        """Collect "DeviceType" values from each step's "Execution" (a dict or a list of dicts)."""
        devices = []
        for step in program:
            execution = step.get("Execution") if "Execution" in step else None
            if isinstance(execution, dict):
                devices.append(execution["DeviceType"])
            elif isinstance(execution, list):
                devices.extend(device_dict["DeviceType"] for device_dict in execution)
        return devices

    def __dump_dataset_metadata(self):
        """Write ``self.dataset_metadata`` back to ``self.dataset_metadata_path``."""
        write_json(self.dataset_metadata_path, self.dataset_metadata)